- Published on
算法005-线段树(优化-矩阵中的局部最大值 II)
- Authors

- Name
- i Joe
线段树
线段树用于查询区间最大最小值,对于一维数组在预处理时时间复杂度为O(nlogn),空间复杂度也是O(nlogn)。下面表格对常用算法作比较。
| 数据结构 | 主要作用 | 查询 | 修改 | 建立 | 空间 | 特点 |
|---|---|---|---|---|---|---|
| 树状数组 | 单点修改+区间和 | O(logn) | O(logn) | O(n) | O(n) | 简洁 |
| ST表 | 静态区间最值 | O(1) | 不支持 | O(nlogn) | O(nlogn) | 查询最快 |
| 线段树 | 动态区间问题 | O(logn) | O(logn) | O(n) | O(4n) | 万能 |
原理
原理很简单,就是将数组区间以树方式呈现。比如[l, r]区间的节点包含它的子区间,也就是对应的左右子树[l, mid]和[mid + 1, r],mid是中间值。这样建立的树也是完全二叉树,所以查询修改是O(logn)。
代码
#include <bits/stdc++.h>
using namespace std;
class SegmentTree {
private:
int n;
vector<long long> tree;
void build(const vector<int>& a, int node, int l, int r) {
if (l == r) {
tree[node] = a[l];
return;
}
int mid = (l + r) / 2;
build(a, node * 2, l, mid);
build(a, node * 2 + 1, mid + 1, r);
tree[node] = tree[node * 2] + tree[node * 2 + 1];
}
void update(int node, int l, int r, int index, int value) {
if (l == r) {
tree[node] = value;
return;
}
int mid = (l + r) / 2;
if (index <= mid) {
update(node * 2, l, mid, index, value);
} else {
update(node * 2 + 1, mid + 1, r, index, value);
}
tree[node] = tree[node * 2] + tree[node * 2 + 1];
}
long long query(int node, int l, int r, int ql, int qr) {
// 当前区间完全被查询区间包含
if (ql <= l && r <= qr) {
return tree[node];
}
int mid = (l + r) / 2;
long long ans = 0;
// 查询区间和左子树有交集
if (ql <= mid) {
ans += query(node * 2, l, mid, ql, qr);
}
// 查询区间和右子树有交集
if (qr > mid) {
ans += query(node * 2 + 1, mid + 1, r, ql, qr);
}
return ans;
}
public:
SegmentTree(const vector<int>& a) {
n = a.size();
tree.assign(4 * n, 0);
build(a, 1, 0, n - 1);
}
// 单点修改:把 a[index] 改成 value
void update(int index, int value) {
update(1, 0, n - 1, index, value);
}
// 查询闭区间 [l, r] 的和
long long query(int l, int r) {
return query(1, 0, n - 1, l, r);
}
};
int main() {
vector<int> a = {2, 1, 5, 3, 4};
SegmentTree seg(a);
cout << seg.query(1, 3) << endl;
// 查询 a[1] + a[2] + a[3] = 1 + 5 + 3 = 9
seg.update(2, 10);
// 把 a[2] 从 5 改成 10
cout << seg.query(1, 3) << endl;
// 查询 a[1] + a[2] + a[3] = 1 + 10 + 3 = 14
return 0;
}
矩阵中的局部最大值 II
给你一个 n x m 的整数矩阵 matrix ,所有元素均为非负整数。
一个 非零 单元格 (row, col) 会按如下方式检查其附近的单元格:
1.令 x = matrix[row][col] 。
2.考虑在 (row, col) 的 x 行和 x 列范围内的每个单元格。
3.忽略矩阵外的单元格。
4.忽略行距离和列距离都恰好等于 x 的 单元格。
如果单元格 (row, col) 是 非零 的,并且所有考虑的单元格中没有一个值 大于 x ,那么该单元格就是一个 局部最大值 。
返回一个整数,表示 matrix 中 局部最大值 的数量。
- 示例 1:
输入:matrix = [[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,2,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0]]
输出: 1
解释:
对于非零单元格 (3, 3) ,x = matrix[3][3] = 2 。
高亮的单元格是在 (3, 3) 的 x 行和 x 列范围内被考虑的单元格。
行距离和列距离都等于 x = 2 的四个单元格被忽略。
没有一个被考虑的单元格的值大于 2 ,因此 (3, 3) 是一个局部最大值。
没有其他非零单元格,所以答案是 1 。
- 示例 2:
输入: matrix = [[1,2],[3,4]]
输出: 1
解释:
只有值为 4 的单元格是局部最大值。其他每个非零单元格都考虑到了一个具有更大值的单元格。
- 示例 3:
输入:matrix = [[1,0,1],[0,1,0],[1,0,1]]
输出:5
解释:
对于值为 1 的单元格,考虑的单元格是其自身及其在矩阵内的 4 个方向上相邻的单元格。
这五个值为 1 的单元格中,每一个都只考虑到值为 0 或 1 的单元格,所以这五个单元格都是局部最大值。
- 示例 4:
输入:matrix = [[1,1],[1,1]]
输出:4
解释:
所有单元格都具有相同的值。因此,没有任何一个单元格会考虑到具有更大值的其他单元格,所以所有 4 个单元格都是局部最大值。
提示:
1 <= n == matrix.length <= 2001 <= m == matrix[i].length <= 2000 <= matrix[i][j] <= 200
解题思路
线段树+st表优化的核心思路是,将列用st表构建,行用线段树构建,即时间复杂度为O(n*mlogm)。随后再通过查询,其中线段树时间复杂度是O(mnlogn),因此整体时间复杂度为O(mn(logn+logm))。
在上述基础上,有两种不同构建方式。这里我把最基本的单列表构建的st表称为基础st表。
- 基于构建的整个基础st表来构建线段树,即:任意
st[i][j]都可以构建出一个列范围为[i, i + 2^j -1]的表。 - 构建线段树时,只依托于子节点的
st[i][0]生成一个新的列,然后再通过st构建出全新的表。
上面两个其实时间复杂度都一样,但第一个代码实现上较复杂,所以选用第二个。
代码
struct STTable {
public:
vector<vector<int>> st;
STTable() {}
STTable(const vector<int>& table) {
int n = table.size();
int kn = bit_width((unsigned)n);
st.resize(n, vector<int>(kn));
for (int i = 0; i < n; i++) {
st[i][0] = table[i];
}
for (int ki = 1; ki < kn; ki++) {
for (int i = 0; i <= n - (1 << ki); i++) {
st[i][ki] = max(st[i][ki - 1], st[i + (1 << (ki - 1))][ki - 1]);
}
}
}
int query(int l, int r) {
int k = bit_width(1u * (r - l)) - 1;
return max(st[l][k], st[r - (1 << k)][k]);
}
};
struct SegmentTree {
public:
vector<STTable> t;
void build(const vector<vector<int>>& a, int node, int l, int r) {
if (l == r) {
t[node] = STTable(a[l]); // 这里时对叶节点进行构建
return;
}
int mid = (l + r) / 2;
build(a, node * 2, l, mid);
build(a, node * 2 + 1, mid + 1, r);
vector<int> merge(a[0].size());
// 这里是先基于st[i][0]生成一个全新的列
for (int i = 0; i < a[0].size(); i++) {
merge[i] = max(t[node * 2].st[i][0], t[node * 2 + 1].st[i][0]);
}
t[node] = STTable(merge); // 然后再将其列构建一个全新的st表
}
SegmentTree(const vector<vector<int>>& a) : t(2 << bit_width(a.size())) {
build(a, 1, 0, a.size() - 1);
}
int query(int node, int l, int r, int r1, int r2, int c1, int c2) {
if (r1 <= l && r <= r2) {
return t[node].query(c1, c2);
}
int m = (l + r) / 2;
if (r2 <= m) {
return query(node * 2, l, m, r1, r2, c1, c2);
}
if (r1 > m) {
return query(node * 2 + 1, m + 1, r, r1, r2, c1, c2);
}
return max(query(node * 2, l, m, r1, r2, c1, c2), query(node * 2 + 1, m + 1, r, r1, r2, c1, c2));
}
};
class Solution {
public:
int countLocalMaximums(vector<vector<int>>& matrix) {
int n = matrix.size();
int m = matrix[0].size();
SegmentTree t(matrix);
int ans = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
int x = matrix[i][j];
if (x > 0 && max(t.query(1, 0, n - 1, max(i - x, 0), min(i + x, n - 1), max(j - x + 1, 0), min(j + x, m)),
t.query(1, 0, n - 1, max(i - x + 1, 0), min(i + x - 1, n - 1), max(j - x, 0), min(j + x + 1, m))) <= x)
ans++;
}
}
return ans;
}
};
